{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "import math\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "\n",
    "import librosa\n",
    "import seaborn as sns\n",
    "import matplotlib.pyplot as plt\n",
    "from scipy import stats\n",
    "from sklearn import linear_model\n",
    "from collections import Counter\n",
    "\n",
    "# load other modules --> repo root path\n",
    "sys.path.insert(0, \"../\")\n",
    "\n",
    "from utils import text\n",
    "from utils import audio\n",
    "from utils.logging import Logger\n",
    "from params.params import Params as hp\n",
    "from dataset.dataset import TextToSpeechDataset, TextToSpeechDatasetCollection, TextToSpeechCollate"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Load dataset and prepare data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "hp.sample_rate = 22050\n",
    "# NOTE(review): 'german' appears twice in this list (positions 0 and 8). item['language'] is used as an\n",
    "# index into hp.languages elsewhere in this notebook, so the list is left unchanged -- confirm the duplicate is intended.\n",
    "hp.languages = [\"german\", \"dutch\", \"french\", \"greek\", \"japanese\", \"russian\", \"chinese\", \"finnish\", \"german\", \"hungarian\", \"spanish\"]\n",
    "\n",
    "# Symbol sets that are merged into the character inventory below.\n",
    "common = ' '\n",
    "greece = 'άέήίαβγδεζηθικλμνξοπρςíστυφχψωόύώ'\n",
    "russian = 'абвгдежзийклмнопрстуфхцчшщъыьэюяё'\n",
    "\n",
    "asciis = 'abcdefghijklmnopqrstuvwxyz'\n",
    "chinese = 'ōǎǐǒàáǔèéìíūòóùúüāēěī'\n",
    "finnish = 'éöä'\n",
    "german = 'ßàäéöü'\n",
    "hungarian = 'őáéíóöűúü'\n",
    "french = 'àâçèéêíôùû'\n",
    "spanish = 'áèéíñóöúü'\n",
    "\n",
    "# The inventory is the union of all symbol sets. sorted() fixes the symbol order: joining the bare set\n",
    "# would yield an order that can differ between interpreter sessions (string hash randomization).\n",
    "hp.characters = ''.join(sorted(set(common + greece + russian + asciis + chinese + finnish + german + hungarian + french + spanish)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hyperparameter overrides applied before the dataset object is created.\n",
    "hp.num_fft = 1102\n",
    "hp.predict_linear = True\n",
    "\n",
    "# The metafile lives inside the dataset root; the root is also passed to the dataset itself.\n",
    "dataset_root = \"../data/css10\"\n",
    "metafile = \"all_reduced.txt\"\n",
    "metafile_path = os.path.join(dataset_root, metafile)\n",
    "data = TextToSpeechDataset(metafile_path, dataset_root)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-utterance statistics, appended in the order of data.items.\n",
    "durations = []\n",
    "lengths = []\n",
    "num_words = []\n",
    "lengths_phon = []\n",
    "languages = []\n",
    "# One symbol counter per language, for characters and for phonemes.\n",
    "freq_chars = {l: Counter() for l in hp.languages}\n",
    "freq_phon = {l: Counter() for l in hp.languages}\n",
    "\n",
    "total_items = len(data.items)\n",
    "Logger.progress(0, prefix='Computing stats:')\n",
    "for step, item in enumerate(data.items, start=1):\n",
    "\n",
    "    # item['language'] is an index into hp.languages.\n",
    "    language = hp.languages[item['language']]\n",
    "    languages.append(language)\n",
    "\n",
    "    waveform = audio.load(os.path.join(dataset_root, item['audio']))\n",
    "    durations.append(audio.duration(waveform))\n",
    "\n",
    "    # Length is measured on the full utterance; word and character counts use the text without punctuation.\n",
    "    utterance = text.to_text(item['text'], use_phonemes=False)\n",
    "    without_punctuation = text.remove_punctuation(utterance)\n",
    "    lengths.append(len(utterance))\n",
    "    num_words.append(len(without_punctuation.split()))\n",
    "    freq_chars[language].update(without_punctuation.replace(' ', ''))\n",
    "\n",
    "    # Phonemized length is measured before spaces and punctuation are stripped for counting.\n",
    "    phonemized = text.to_text(item['phonemes'], use_phonemes=True)\n",
    "    lengths_phon.append(len(phonemized))\n",
    "    freq_phon[language].update(text.remove_punctuation(phonemized.replace(' ', '')))\n",
    "\n",
    "    Logger.progress(step / total_items, prefix='Computing stats:')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Item from data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inspect the first dataset item: stored text entry, decoded text, decoded phonemes and audio duration.\n",
    "item = data.items[0]\n",
    "waveform = audio.load(os.path.join(dataset_root, item['audio']))\n",
    "\n",
    "print(item['text'])\n",
    "print(text.to_text(item['text'], use_phonemes=False))\n",
    "print(text.to_text(item['phonemes'], use_phonemes=True))\n",
    "print(audio.duration(waveform))\n",
    "\n",
    "# Spectrogram representations of the same waveform.\n",
    "melspec = audio.mel_spectrogram(waveform)\n",
    "spec = audio.spectrogram(waveform)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Data analysis"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Set the default figure size to 16 x 4 and use seaborn's plain 'white' style for the plots below.\n",
    "sns.set(rc={'figure.figsize':(16,4)})\n",
    "sns.set_style(\"white\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Gather the per-utterance statistics into a single table with explicit dtypes.\n",
    "column_order = ['Words', 'Length', 'Duration', 'LengthPhon', 'Language']\n",
    "series_by_column = {\n",
    "    'Words': pd.Series(num_words, dtype='int'),\n",
    "    'Length': pd.Series(lengths, dtype='int'),\n",
    "    'Duration': pd.Series(durations, dtype='float'),\n",
    "    'LengthPhon': pd.Series(lengths_phon, dtype='int'),\n",
    "    'Language': pd.Series(languages, dtype='category'),\n",
    "}\n",
    "df = pd.DataFrame(series_by_column, columns=column_order)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Range filters on text length and duration, applied one after another;\n",
    "# the number of remaining rows is printed before the first and after every filter.\n",
    "row_filters = (\n",
    "    lambda frame: frame['Length'] < 190,\n",
    "    lambda frame: frame['Duration'] < 10.1,\n",
    "    lambda frame: frame['Duration'] > 0.5,\n",
    "    lambda frame: frame['Length'] > 2,\n",
    ")\n",
    "\n",
    "print(len(df))\n",
    "for keep_rows in row_filters:\n",
    "    df = df[keep_rows(df)]\n",
    "    print(len(df))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-language outlier removal: fit a linear model Duration ~ Length on each language group and keep\n",
    "# only the rows whose duration differs from the fitted value by less than log10(Length) + 1.\n",
    "kept_groups = []\n",
    "for name, group in df.groupby('Language'):\n",
    "    if group.empty:\n",
    "        # A language without remaining rows cannot be fitted; it contributes nothing.\n",
    "        continue\n",
    "\n",
    "    # scikit-learn expects 2-D inputs, hence the reshape to a single column.\n",
    "    length_column = group['Length'].values.reshape(-1, 1)\n",
    "    duration_column = group['Duration'].values.reshape(-1, 1)\n",
    "    lr = linear_model.LinearRegression().fit(length_column, duration_column)\n",
    "    group_mean = lr.predict(length_column).squeeze(-1)\n",
    "\n",
    "    tolerance = np.log10(group['Length']) + 1\n",
    "    kept_groups.append(group[abs(group['Duration'] - group_mean) < tolerance])\n",
    "\n",
    "# Concatenate once instead of growing a frame inside the loop. The list is reversed so that the rows\n",
    "# come out in the same order as prepending each group to the running result would give.\n",
    "total = pd.concat(kept_groups[::-1]) if kept_groups else df.iloc[0:0]\n",
    "\n",
    "df = total\n",
    "print(len(df))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# out_file = \"idxes_clean.txt\"\n",
    "# with open(os.path.join(dataset_root, out_file), mode='w') as f:\n",
    "#     for i in sorted(df.index):\n",
    "#         print(f'{i}'.zfill(6), file=f)\n",
    "\n",
    "# join -t '|' idxes_clean.txt a.txt > b.txt"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Duration distribution"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Summed duration per language and overall, divided by 3600.\n",
    "SECONDS_PER_HOUR = 3600\n",
    "for name, group in df.groupby('Language'):\n",
    "    language_hours = sum(group['Duration']) / SECONDS_PER_HOUR\n",
    "    print(f'{name}:\\t{language_hours}')\n",
    "total_hours = sum(df['Duration']) / SECONDS_PER_HOUR\n",
    "print(f'Total:\\t{total_hours}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# One duration histogram per language, with a KDE curve (blue) and a fitted normal curve (red).\n",
    "for name, group in df.groupby('Language'):\n",
    "    language_durations = group['Duration']\n",
    "    print(f'Min duration: {min(language_durations)}')\n",
    "    print(f'Max duration: {max(language_durations)}')\n",
    "    plot_options = dict(\n",
    "        hist=True, rug=False, fit=stats.norm, color=\"c\",\n",
    "        kde_kws={\"color\": \"b\", \"lw\": 3},\n",
    "        fit_kws={\"color\": \"r\", \"lw\": 3},\n",
    "    )\n",
    "    ax = sns.distplot(language_durations, **plot_options)\n",
    "    ax.set(xlabel='Duration (s)', title=name);\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Length distribution"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One text-length distribution per language, with a KDE curve (blue) and a fitted normal curve (red).\n",
    "for name, group in df.groupby('Language'):\n",
    "    language_lengths = group['Length']\n",
    "    print(f'Min length: {min(language_lengths)}')\n",
    "    print(f'Max length: {max(language_lengths)}')\n",
    "    plot_options = dict(\n",
    "        kde=True, rug=False, fit=stats.norm, color=\"c\",\n",
    "        kde_kws={\"color\": \"b\", \"lw\": 3},\n",
    "        fit_kws={\"color\": \"r\", \"lw\": 3},\n",
    "    )\n",
    "    ax = sns.distplot(language_lengths, **plot_options)\n",
    "    ax.set(xlabel='Length', title=name);\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Word count distribution"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Word-count distribution over all languages, with a KDE curve (blue) and a fitted normal curve (red).\n",
    "plot_options = dict(\n",
    "    kde=True, rug=False, fit=stats.norm, color=\"c\",\n",
    "    kde_kws={\"color\": \"b\", \"lw\": 3},\n",
    "    fit_kws={\"color\": \"r\", \"lw\": 3},\n",
    ")\n",
    "ax = sns.distplot(df['Words'], **plot_options)\n",
    "ax.set(xlabel='Word count');"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Phonemized length distribution"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Phonemized-length distribution over all languages, with a KDE curve (blue) and a fitted normal curve (red).\n",
    "plot_options = dict(\n",
    "    kde=True, rug=False, fit=stats.norm, color=\"c\",\n",
    "    kde_kws={\"color\": \"b\", \"lw\": 3},\n",
    "    fit_kws={\"color\": \"r\", \"lw\": 3},\n",
    ")\n",
    "ax = sns.distplot(df['LengthPhon'], **plot_options)\n",
    "ax.set(xlabel='Phonemized length');"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Duration vs Length"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One hexbin joint plot of text length against duration per language.\n",
    "for name, group in df.groupby('Language'):\n",
    "    # x and y are passed by keyword: recent seaborn releases no longer accept them positionally.\n",
    "    ax = sns.jointplot(x=group['Length'], y=group['Duration'], kind=\"hex\", space=0, color=\"b\")\n",
    "    ax.fig.set_figwidth(7)\n",
    "    ax.ax_joint.set(xlabel='Length', ylabel='Duration', title=name);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Switch to the 'whitegrid' style for the line plot in the next cell.\n",
    "sns.set_style(\"whitegrid\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Line plot of duration against text length with a standard-deviation band (ci=\"sd\").\n",
    "duration_min = min(df['Duration'])\n",
    "duration_max = max(df['Duration'])\n",
    "\n",
    "ax = sns.relplot(data=df, x=\"Length\", y=\"Duration\", kind=\"line\", ci=\"sd\", linewidth=3)\n",
    "ax.fig.set_figwidth(15)\n",
    "ax.fig.set_figheight(4)\n",
    "\n",
    "# Ticks every 2 units, starting at the rounded minimum duration.\n",
    "y_ticks = np.arange(round(duration_min), duration_max + 1, 2)\n",
    "ax.set(yticks=y_ticks)\n",
    "plt.ylim(duration_min - 1, duration_max + 1);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Switch back to the plain 'white' style for the plots that follow.\n",
    "sns.set_style(\"white\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Duration vs Phonemized length"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hexbin joint plot of phonemized length against duration over all languages.\n",
    "# x and y are passed by keyword: recent seaborn releases no longer accept them positionally.\n",
    "ax = sns.jointplot(x=df['LengthPhon'], y=df['Duration'], kind=\"hex\", space=0, color=\"b\")\n",
    "ax.fig.set_figwidth(7)\n",
    "# The x axis shows the LengthPhon column, so it is labelled as phonemized length.\n",
    "ax.ax_joint.set(xlabel='Phonemized length', ylabel='Duration');"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Switch to the 'whitegrid' style for the line plot in the next cell.\n",
    "sns.set_style(\"whitegrid\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Line plot of duration against phonemized length with a standard-deviation band (ci=\"sd\").\n",
    "ax = sns.relplot(x=\"LengthPhon\", y=\"Duration\", kind=\"line\", ci=\"sd\", linewidth=3, data=df)\n",
    "ax.fig.set_figwidth(15)\n",
    "ax.fig.set_figheight(4)\n",
    "ax.set(yticks=np.arange(round(min(df['Duration'])), max(df['Duration']) + 1, 2))\n",
    "# Both limits come from the filtered frame; the raw `durations` list still contains the rows\n",
    "# that the filtering cells removed from df.\n",
    "plt.ylim(min(df['Duration']) - 1, max(df['Duration']) + 1);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Switch back to the plain 'white' style for the plots that follow.\n",
    "sns.set_style(\"white\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Phonemes distribution"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Phoneme inventory from the hyperparameters with the space symbol removed.\n",
    "symbols_phon = ''.join(hp.phonemes.split(' '))\n",
    "symbols_phon"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# One bar chart of phoneme counts per language, with the symbols in sorted order.\n",
    "for language, counts in freq_phon.items():\n",
    "    symbols = sorted(counts)\n",
    "    frequencies = [counts[symbol] for symbol in symbols]\n",
    "    sns.barplot(x=symbols, y=frequencies).set_title(language)\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# Sum the phoneme counts of all languages and plot them in the order of symbols_phon.\n",
    "total = Counter()\n",
    "for language_counts in freq_phon.values():\n",
    "    total.update(language_counts)\n",
    "\n",
    "totals_in_symbol_order = [total[symbol] for symbol in symbols_phon]\n",
    "sns.barplot(x=list(symbols_phon), y=totals_in_symbol_order).set_title(\"Total\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# All phoneme symbols that were counted, joined into one string.\n",
    "''.join(total)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
